# Replication materials for "Messy Data, Robust Inference? Navigating Obstacles to Inference with bigKRLS"
# By: Pete Mohanty (pmohanty@stanford.edu) and Robert Shaffer (rbshaffer@utexas.edu)

library(dplyr)
library(stringr)
library(reshape2)
library(ggplot2)

# Load the saved simulation output. The code below reads its results from
# `to_save`, which this .RData file is expected to define.
load('~/Dropbox/pa_replication/Appendix_C/Appendix_C.1-2/results_sims_popsd.RData')

# to_save is the output object - this loop extracts the sample AME estimates from each model run.
# Every run becomes one row: sample size (N), DGP label (Type), then the run's
# AME estimates (avgderiv_sample) as additional columns.
varnames <- c('Mortality', 'Unemployment', 'Rural', 'Age', 'Income', 'Poverty', 'College', 'White')

# Rows are collected in a preallocated list and bound once after the loop,
# rather than growing the data frame with rbind() on every iteration.
avgderiv_rows <- vector('list', length(to_save))
for(i in seq_along(to_save)){
  N <- to_save[[i]]$N
  
  avgderiv <- to_save[[i]]$avgderiv_sample
  # any Type other than 'simple' is labelled 'Complex'
  if(to_save[[i]]$Type == 'simple'){
    to_add <- data.frame('N' = N, 'Type' = 'Simple')
  } else{
    to_add <- data.frame('N' = N, 'Type' = 'Complex')
  }
  
  # append the AME estimates as columns 3, 4, ... after N and Type
  to_add[3:(length(avgderiv)+2)] <- avgderiv
  avgderiv_rows[[i]] <- to_add
}

# The leading empty data frame is ignored by rbind() whenever there are rows,
# and is what comes back if to_save is empty.
avgderiv_df <- do.call(rbind, c(list(data.frame()), avgderiv_rows))

# Label the columns, then estimate the population SD of each AME across
# repeated samples: within every (N, Type) cell, take the standard deviation
# of the sample AMEs.
colnames(avgderiv_df) <- c('N', 'Type', varnames)
pop_sd <- summarise_all(group_by(avgderiv_df, N, Type), sd)

# create a data frame which tracks the difference between the population SD and the SD estimates for each sample (corrected and uncorrected)
# Each model run contributes two rows: its corrected and its uncorrected SE
# estimates, each minus the population SD for the run's sample size and DGP.
# Rows are stored in a preallocated list and bound once after the loop.
diff_rows <- vector('list', 2 * length(to_save))
for(i in seq_along(to_save)){
  N <- to_save[[i]]$N
  
  # any Type other than 'simple' is treated as the complex DGP
  dgp <- if(to_save[[i]]$Type == 'simple') 'Simple' else 'Complex'
  
  # population SDs for this (N, DGP) cell; columns 1-2 of pop_sd are N and Type
  pop_row <- pop_sd[pop_sd$N == N & pop_sd$Type == dgp, 3:ncol(pop_sd)]
  
  diff_rows[[2*i - 1]] <- data.frame(N, 'Type'= 'Corrected', 'DGP' = dgp, 
                                     to_save[[i]]$se_corr - pop_row, stringsAsFactors = FALSE)
  diff_rows[[2*i]] <- data.frame(N, 'Type'= 'Uncorrected', 'DGP' = dgp, 
                                 to_save[[i]]$se_uncorr - pop_row, stringsAsFactors = FALSE)
}

# The leading empty data frame is ignored by rbind() whenever there are rows,
# and is what comes back if to_save is empty.
diffs <- do.call(rbind, c(list(data.frame()), diff_rows))

# Performance measure: within every (N, Type, DGP) cell, each variable's
# differences are summarised as sqrt(mean(x^2)), i.e. root mean squared error.
# Swap the function handed to summarise_all to use a different measure.
to_plot <- summarise_all(group_by(diffs, N, Type, DGP),
                         function(v){sqrt(mean(v^2))})

# For each sample size, print the share of coefficient estimates (complex DGP
# only) for which the corrected SE has a smaller error than the uncorrected SE,
# i.e. for which our correction is closer to the population SD.
#
# Only the per-variable error columns are compared. The identifier columns
# (N, Type, DGP) are excluded by name; selecting by position from column 3
# would also pull in DGP. The divisor is the number of compared columns.
est_cols <- setdiff(names(to_plot), c('N', 'Type', 'DGP'))
for(nval in unique(to_plot$N)){
  in_cell <- to_plot$N == nval & to_plot$DGP == 'Complex'
  err_corrected <- to_plot[in_cell & to_plot$Type == 'Corrected', est_cols]
  err_uncorrected <- to_plot[in_cell & to_plot$Type == 'Uncorrected', est_cols]
  
  # sample size is written without a newline, so the share follows on the same line
  cat(nval)
  print(sum(err_corrected < err_uncorrected)/length(est_cols))
  
}

# plotting the raw results
# Reshape to long format: one row per (N, Type, DGP, variable), with the
# error measure in `value`.
to_plot <- melt(to_plot, id.vars=c('N', 'Type', 'DGP'))

# One panel per variable/DGP pair, corrected vs. uncorrected by colour.
# The plotted value is computed upstream as sqrt(mean(x^2)), so the y-axis is
# labelled as a root mean squared error.
ggplot(to_plot, aes(x=N, y=value, color=Type)) + 
  geom_point() + 
  geom_line() + 
  facet_wrap(~variable + DGP, ncol=2, scales='free_y') + 
  theme_minimal() + 
  theme(axis.text.x = element_text(angle = 45, hjust=1, vjust=1.5)) + 
  ylab('Root Mean Squared Error')

# plotting coverage
# Collect every run's coverage values in long format: for each run, one data
# frame for the corrected and one for the uncorrected estimates, tagged with
# the run's sample size (N) and DGP. Rows are stored in a preallocated list
# and bound once after the loop.
coverage_rows <- vector('list', 2 * length(to_save))
for(i in seq_along(to_save)){
  run <- to_save[[i]]
  coverage_rows[[2*i - 1]] <- data.frame('N'=run$N, 'DGP'=run$Type, 'Type'='Corrected', 'Coverage'=run$coverage$corrected)
  coverage_rows[[2*i]] <- data.frame('N'=run$N, 'DGP'=run$Type, 'Type'='Uncorrected', 'Coverage'=run$coverage$uncorrected)
}

# Zero-row template with the intended columns (N, DGP, Type, Coverage).
# rbind() ignores it whenever there are rows; it is the result if to_save is empty.
coverage_template <- data.frame('N'=numeric(), 'DGP'=character(), 'Type'=character(), 'Coverage'=numeric())
coverage_df <- do.call(rbind, c(list(coverage_template), coverage_rows))

# Tidy the coverage data and average it within each (N, DGP, Type) cell.
# NOTE(review): plyr is attached here for revalue(); attaching plyr after dplyr
# masks some dplyr verbs - confirm nothing later relies on the masked names.
library(plyr)
coverage_df[['DGP']] <- revalue(coverage_df[['DGP']],
                                c("simple"="Simple", "complex"="Complex"))
# presumably 8 is the number of coefficients per model (length of varnames),
# turning a count into a proportion - TODO confirm
coverage_df[['Coverage']] <- coverage_df[['Coverage']]/8
to_plot_coverage <- summarise_all(group_by(coverage_df, N, DGP, Type), mean)

# Coverage by sample size, one panel per DGP, corrected vs. uncorrected by
# colour; the y-axis title is suppressed.
ggplot(to_plot_coverage, aes(x=N, y=Coverage, color=Type)) + 
  geom_point() + 
  geom_line() + 
  facet_wrap(~DGP, ncol=2) + 
  theme_minimal() + 
  theme(axis.text.x = element_text(angle = 45, hjust=1, vjust=1.5),
        axis.title.y = element_blank())
